{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4weyZUFfgUlD"
      },
      "source": [
        "# StableLM-Alpha\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Stability-AI/StableLM/blob/main/notebooks/stablelm-alpha.ipynb)\n",
        "\n",
        "<img src=\"https://raw.githubusercontent.com/Stability-AI/StableLM/main/assets/mascot.png\"/>\n",
        "\n",
        "This notebook is designed to let you quickly generate text with the latest StableLM models (**StableLM-Alpha**) using Hugging Face's `transformers` library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8xicyuk_Ezuw"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V1Da2YDX71IF"
      },
      "outputs": [],
      "source": [
        "# Use the %pip magic (not !pip) so packages install into the running kernel's environment.\n",
        "%pip install -U pip\n",
        "%pip install accelerate bitsandbytes torch transformers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "cellView": "form",
        "id": "sSifeGXKlIgY"
      },
      "outputs": [],
      "source": [
        "#@title Setup\n",
        "\n",
        "import torch\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList\n",
        "\n",
        "from IPython.display import Markdown, display\n",
        "def hr():\n",
        "    \"\"\"Render a Markdown horizontal rule (`---`) in the cell output.\"\"\"\n",
        "    rule = Markdown('---')\n",
        "    display(rule)\n",
        "\n",
        "def cprint(msg: str, color: str = \"blue\", **kwargs) -> str:\n",
        "    if color == \"blue\": print(\"\\033[34m\" + msg + \"\\033[0m\", **kwargs)\n",
        "    elif color == \"red\": print(\"\\033[31m\" + msg + \"\\033[0m\", **kwargs)\n",
        "    elif color == \"green\": print(\"\\033[32m\" + msg + \"\\033[0m\", **kwargs)\n",
        "    elif color == \"yellow\": print(\"\\033[33m\" + msg + \"\\033[0m\", **kwargs)\n",
        "    elif color == \"purple\": print(\"\\033[35m\" + msg + \"\\033[0m\", **kwargs)\n",
        "    elif color == \"cyan\": print(\"\\033[36m\" + msg + \"\\033[0m\", **kwargs)\n",
        "    else: raise ValueError(f\"Invalid info color: `{color}`\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "dQZCeE-ujdzW"
      },
      "outputs": [],
      "source": [
        "#@title Pick Your Model\n",
        "#@markdown Refer to the Hugging Face docs for more information on the parameters below: https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained\n",
        "\n",
        "# Checkpoint to load\n",
        "model_name = \"stabilityai/stablelm-tuned-alpha-7b\" #@param [\"stabilityai/stablelm-tuned-alpha-7b\", \"stabilityai/stablelm-base-alpha-7b\", \"stabilityai/stablelm-tuned-alpha-3b\", \"stabilityai/stablelm-base-alpha-3b\"]\n",
        "\n",
        "cprint(f\"Using `{model_name}`\", color=\"blue\")\n",
        "\n",
        "# \"Big model inference\" loading options\n",
        "torch_dtype = \"float16\" #@param [\"float16\", \"bfloat16\", \"float\"]\n",
        "load_in_8bit = False #@param {type:\"boolean\"}\n",
        "device_map = \"auto\"\n",
        "\n",
        "cprint(f\"Loading with: `{torch_dtype=}, {load_in_8bit=}, {device_map=}`\")\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "\n",
        "# The form stores the dtype as a name; look up the matching attribute on `torch`.\n",
        "load_options = dict(\n",
        "    torch_dtype=getattr(torch, torch_dtype),\n",
        "    load_in_8bit=load_in_8bit,\n",
        "    device_map=device_map,\n",
        "    offload_folder=\"./offload\",\n",
        ")\n",
        "model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {
        "cellView": "form",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 327
        },
        "id": "P01Db-SVwtPO",
        "outputId": "9911dead-44b8-43e2-de73-c40857131065"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[34mSampling with: `max_new_tokens=128, temperature=0.7, top_k=0, top_p=0.9, do_sample=True`\u001b[0m\n"
          ]
        },
        {
          "data": {
            "text/markdown": [
              "---"
            ],
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Can you write a song about a pirate at sea? \u001b[32mSure, here's a song about a pirate at sea:\n",
            "\n",
            "Verse 1:\n",
            "There he was, a pirate so bold\n",
            "Sailing the seas, his story untold\n",
            "His name was Captain Jack, and he ruled the waves\n",
            "A legend in the seas, he conquered all his foes\n",
            "\n",
            "Chorus:\n",
            "Oh, Captain Jack, the pirate of the sea\n",
            "Your bravery and your daring, set us all free\n",
            "From the tyranny of the sea, you led us to glory\n",
            "A legend in our hearts, you'll be remembered as our story\n",
            "\n",
            "Verse 2:\n",
            "He sailed the\u001b[0m\n"
          ]
        }
      ],
      "source": [
        "#@title Generate Text\n",
        "#@markdown <b>Note: The model response is colored in green</b>\n",
        "\n",
        "class StopOnTokens(StoppingCriteria):\n",
        "    \"\"\"Stopping criterion that ends generation once the newest token is a stop token.\n",
        "\n",
        "    Only the first sequence in the batch (`input_ids[0]`) is inspected.\n",
        "\n",
        "    Args:\n",
        "        stop_ids: Token ids that end generation. The default is the set of ids this\n",
        "            notebook has always used (presumably StableLM's role/special tokens --\n",
        "            verify against the tokenizer).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, stop_ids=(50278, 50279, 50277, 1, 0)):\n",
        "        super().__init__()\n",
        "        # Build the lookup once instead of on every generation step.\n",
        "        self.stop_ids = frozenset(stop_ids)\n",
        "\n",
        "    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n",
        "        # A single `.item()` read replaces up to five per-id tensor comparisons.\n",
        "        return input_ids[0][-1].item() in self.stop_ids\n",
        "\n",
        "# Process the user prompt\n",
        "user_prompt = \"Can you write a song about a pirate at sea?\" #@param {type:\"string\"}\n",
        "if \"tuned\" in model_name:\n",
        "  # Add system prompt for chat tuned models.\n",
        "  # Built from adjacent string literals rather than an indented triple-quoted\n",
        "  # string, so this cell's indentation does not leak into the prompt text.\n",
        "  system_prompt = (\n",
        "    \"<|SYSTEM|># StableLM Tuned (Alpha version)\\n\"\n",
        "    \"- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.\\n\"\n",
        "    \"- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\\n\"\n",
        "    \"- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.\\n\"\n",
        "    \"- StableLM will refuse to participate in anything that could harm a human.\\n\"\n",
        "  )\n",
        "  prompt = f\"{system_prompt}<|USER|>{user_prompt}<|ASSISTANT|>\"\n",
        "else:\n",
        "  # Base models get the raw user text with no chat markup\n",
        "  prompt = user_prompt\n",
        "\n",
        "# Sampling args\n",
        "max_new_tokens = 128 #@param {type:\"slider\", min:32.0, max:3072.0, step:32}\n",
        "temperature = 0.7 #@param {type:\"slider\", min:0.0, max:1.25, step:0.05}\n",
        "# `top_k` is a token count, so its slider must be integer-valued; 0 leaves top-k filtering off.\n",
        "top_k = 0 #@param {type:\"slider\", min:0, max:100, step:1}\n",
        "top_p = 0.9 #@param {type:\"slider\", min:0.0, max:1.0, step:0.05}\n",
        "do_sample = True #@param {type:\"boolean\"}\n",
        "\n",
        "cprint(f\"Sampling with: `{max_new_tokens=}, {temperature=}, {top_k=}, {top_p=}, {do_sample=}`\")\n",
        "hr()\n",
        "\n",
        "# Tokenize the prompt and place the tensors on the model's device\n",
        "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
        "\n",
        "# Generate, halting early once `StopOnTokens` reports a stop token\n",
        "stopping_criteria = StoppingCriteriaList([StopOnTokens()])\n",
        "tokens = model.generate(\n",
        "    **inputs,\n",
        "    max_new_tokens=max_new_tokens,\n",
        "    temperature=temperature,\n",
        "    top_k=top_k,\n",
        "    top_p=top_p,\n",
        "    do_sample=do_sample,\n",
        "    pad_token_id=tokenizer.eos_token_id,\n",
        "    stopping_criteria=stopping_criteria,\n",
        ")\n",
        "\n",
        "# Keep only the tokens generated after the prompt\n",
        "prompt_length = inputs[\"input_ids\"].shape[1]\n",
        "completion_tokens = tokens[0][prompt_length:]\n",
        "completion = tokenizer.decode(completion_tokens, skip_special_tokens=True)\n",
        "\n",
        "# Show the user's prompt, then the model's completion in green\n",
        "print(user_prompt, end=\" \")\n",
        "cprint(completion, color=\"green\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rIZm5uwaQLa4"
      },
      "source": [
        "## License (Apache 2.0)\n",
        "\n",
        "Copyright (c) 2023 by [StabilityAI LTD](https://stability.ai/)\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "machine_shape": "hm",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
